{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center> Vanilla Generative Adversarial Network for NSL-KDD </center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "from preprocessing import *\n",
    "from classifiers import *\n",
    "from utils import *\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, Input, Dropout\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras import backend as K\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read Data & Scale Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train/test splits; label_mapping is indexed by class name further down (e.g. \"probe\").\n",
    "train,test, label_mapping = get_data(encoding=\"Label\")\n",
    "# Feature columns = every column except the target 'label'.\n",
    "data_cols = list(train.columns[ train.columns != 'label' ])\n",
    "# NOTE(review): \"Robust\" presumably selects the scaler and True an extra preprocessing option -- confirm in preprocessing.py.\n",
    "x_train , x_test = preprocess(train,test,data_cols,\"Robust\",True)\n",
    "\n",
    "# The preprocessed frames still carry the 'label' column; pull the targets out via .values.\n",
    "y_train = x_train.label.values\n",
    "y_test = x_test.label.values\n",
    "\n",
    "# Recompute the feature list from the preprocessed frame (its columns may differ from the raw ones -- TODO confirm).\n",
    "data_cols = list(x_train.columns[ x_train.columns != 'label' ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Generator, Discriminator & Full Generative Adversarial Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_discriminator(data_dim, min_num_neurones):\n",
    "    \"\"\"\n",
    "    Build and compile the discriminator: a feed-forward net ending in one sigmoid unit.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    data_dim : int\n",
    "        Number of input features per sample\n",
    "    min_num_neurones : int\n",
    "        Width of the second hidden layer; the first hidden layer is twice as wide\n",
    "\n",
    "    Return Value: Sequential model named 'Discriminator', compiled with\n",
    "                  binary cross-entropy loss and the 'sgd' optimizer\n",
    "    \"\"\"\n",
    "    hidden_wide = Dense(min_num_neurones*2, activation='relu', input_dim=data_dim)\n",
    "    hidden_narrow = Dense(min_num_neurones, activation='relu')\n",
    "    probability = Dense(1, activation='sigmoid')\n",
    "\n",
    "    discriminator = tf.keras.models.Sequential(name='Discriminator')\n",
    "    for layer in (hidden_wide, hidden_narrow, probability):\n",
    "        discriminator.add(layer)\n",
    "\n",
    "    discriminator.compile(loss='binary_crossentropy', optimizer=\"sgd\")\n",
    "    return discriminator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_generator(data_dim, min_num_neurones, noise_dim):\n",
    "    \"\"\"\n",
    "    Build and compile the generator: noise vector in, one linear output per data feature.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    data_dim : int\n",
    "        Number of units in the (linear) output layer\n",
    "    min_num_neurones : int\n",
    "        Width of the first hidden layer; the next two are 2x and 4x as wide\n",
    "    noise_dim : int\n",
    "        Size of the input noise vector\n",
    "\n",
    "    Return Value: Sequential model named 'Generator', compiled with\n",
    "                  binary cross-entropy loss and the 'sgd' optimizer\n",
    "    \"\"\"\n",
    "    generator = tf.keras.models.Sequential(name='Generator')\n",
    "\n",
    "    # Only the first layer needs to be told the input size.\n",
    "    generator.add(Dense(min_num_neurones, activation='relu', input_dim=noise_dim))\n",
    "    for width_factor, activation in ((2, 'relu'), (4, 'tanh')):\n",
    "        generator.add(Dense(min_num_neurones*width_factor, activation=activation))\n",
    "\n",
    "    # No activation on the output layer.\n",
    "    generator.add(Dense(data_dim))\n",
    "\n",
    "    generator.compile(loss='binary_crossentropy', optimizer=\"sgd\")\n",
    "    return generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gan(discriminator, generator, z_dim):\n",
    "    \"\"\"\n",
    "    Chain generator -> discriminator into a single compiled model.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    discriminator : keras model\n",
    "        Applied to the generator's output; its `trainable` flag is set to False here\n",
    "        (a side effect on the object that was passed in)\n",
    "    generator : keras model\n",
    "        Applied to the noise input\n",
    "    z_dim : int\n",
    "        Size of the noise input vector\n",
    "\n",
    "    Return Value: Model mapping noise to the discriminator's output, compiled with\n",
    "                  binary cross-entropy loss and the 'sgd' optimizer\n",
    "    \"\"\"\n",
    "    # Must happen before the combined model is compiled below.\n",
    "    discriminator.trainable = False\n",
    "\n",
    "    latent = Input(shape=(z_dim,))\n",
    "    verdict = discriminator(generator(latent))\n",
    "\n",
    "    combined = Model(inputs=latent, outputs=verdict)\n",
    "    combined.compile(loss='binary_crossentropy', optimizer='sgd')\n",
    "    return combined"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define batch generation & GAN training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batch(X, batch_size=1):\n",
    "    \"\"\"\n",
    "    Parameters:\n",
    "    -----------\n",
    "    X : ndarray\n",
    "        The input data to sample a into batch\n",
    "    size : int (default = 1)\n",
    "        Batch size\n",
    "\n",
    "    Return Value: ndarray - random choice of samples from the input X of batch_size\n",
    "    \"\"\"\n",
    "    batch_ix = np.random.choice(len(X), batch_size, replace=False)\n",
    "    return X[batch_ix]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(arguments,X):\n",
    "    \"\"\"\n",
    "    Build a fresh generator/discriminator pair and train them adversarially on X.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    arguments : list\n",
    "        [rand_noise_dim, nb_steps, batch_size, D_epochs, G_epochs, min_num_neurones]:\n",
    "        noise vector size, number of outer training steps, mini-batch size,\n",
    "        discriminator updates per step, generator updates per step and the\n",
    "        base hidden-layer width handed to create_generator / create_discriminator\n",
    "    X : ndarray\n",
    "        Real samples, one row per sample; X.shape[1] is used as the data dimension\n",
    "\n",
    "    Return Value: dict with the models (\"generator_model\", \"discriminator_model\",\n",
    "                  \"combined_model\") and one loss value per outer step\n",
    "                  (\"generator_loss\", \"disc_loss_generated\", \"disc_loss_real\")\n",
    "    \"\"\"\n",
    "    rand_noise_dim, nb_steps, batch_size, D_epochs, G_epochs, min_num_neurones = arguments\n",
    "\n",
    "    data_dim = X.shape[1]\n",
    "    combined_loss, disc_loss_generated, disc_loss_real = [], [], []\n",
    "\n",
    "    # Creating GAN\n",
    "    generator = create_generator(data_dim, min_num_neurones, rand_noise_dim)\n",
    "    discriminator = create_discriminator(data_dim, min_num_neurones)\n",
    "    adversarial_model = create_gan(discriminator, generator, rand_noise_dim)\n",
    "\n",
    "    # Start training\n",
    "    for epoch in range(1, nb_steps + 1):\n",
    "        K.set_learning_phase(1)\n",
    "\n",
    "        # Train Discriminator\n",
    "        # NOTE(review): `trainable` is flipped here without re-compiling; Keras documents that\n",
    "        # compile() freezes the trainable state, so confirm these assignments have any effect.\n",
    "        discriminator.trainable = True\n",
    "        for i in range(D_epochs):\n",
    "            # Re-seeding makes the noise and the real batch of a given step repeatable.\n",
    "            np.random.seed(i+epoch)\n",
    "\n",
    "            # The noise width comes from `arguments`, so the function does not rely on notebook globals.\n",
    "            noise = np.random.normal(0, 1, size=(batch_size, rand_noise_dim))\n",
    "            generated_samples = generator.predict(noise)\n",
    "            real_samples = get_batch(X, batch_size)\n",
    "\n",
    "            # Targets are drawn just below 1 for real rows and just above 0 for generated rows.\n",
    "            d_l_r = discriminator.train_on_batch(real_samples, np.random.uniform(low=0.999, high=1.0, size=batch_size))\n",
    "            d_l_g = discriminator.train_on_batch(generated_samples, np.random.uniform(low=0.0, high=0.0001, size=batch_size))\n",
    "\n",
    "        # Freeze Discriminator\n",
    "        discriminator.trainable = False\n",
    "        # Only the losses of the last discriminator update of this step are recorded.\n",
    "        disc_loss_generated.append(d_l_g)\n",
    "        disc_loss_real.append(d_l_r)\n",
    "\n",
    "        # Train Generator (through the combined model, with targets close to 1)\n",
    "        for i in range(G_epochs):\n",
    "            np.random.seed(i+epoch)\n",
    "\n",
    "            noise = np.random.normal(0, 1, size=(batch_size, rand_noise_dim))\n",
    "            loss = adversarial_model.train_on_batch(noise, np.random.uniform(low=0.999, high=1.0, size=batch_size))\n",
    "\n",
    "        combined_loss.append(loss)\n",
    "\n",
    "        # Do checkpointing\n",
    "        if epoch % 10 == 0:\n",
    "            K.set_learning_phase(0)\n",
    "\n",
    "            # Score as many generated samples as there are real ones, drawing the noise from\n",
    "            # the same N(0, 1) distribution the generator is trained on.\n",
    "            z = np.random.normal(0, 1, size=(len(X), rand_noise_dim))\n",
    "\n",
    "            fake_pred = np.array(adversarial_model.predict(z)).ravel()\n",
    "            real_pred = np.array(discriminator.predict(X)).ravel()\n",
    "\n",
    "            # modelAccuracy is not defined in this notebook (presumably star-imported from utils).\n",
    "            modelAccuracy(fake_pred, real_pred)\n",
    "\n",
    "    return {\n",
    "        \"generator_model\": generator,\n",
    "        \"discriminator_model\": discriminator,\n",
    "        \"combined_model\": adversarial_model,\n",
    "        \"generator_loss\": combined_loss,\n",
    "        \"disc_loss_generated\": disc_loss_generated,\n",
    "        \"disc_loss_real\": disc_loss_real,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter Train samples and set training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "K.clear_session()\n",
    "\n",
    "# Real samples to train on: the rows of the training set labelled \"probe\".\n",
    "att_ind = np.where(y_train == label_mapping[\"probe\"])[0]\n",
    "x = x_train[data_cols].values[att_ind]\n",
    "\n",
    "# Sizes\n",
    "n_to_generate = 2000\n",
    "rand_dim = 32        # noise vector size (first entry of the training arguments)\n",
    "base_n_count = 100   # base hidden-layer width (last entry of the training arguments)\n",
    "\n",
    "# Training schedule\n",
    "combined_ep = 1000   # outer training steps\n",
    "batch_size = min(128, len(x))  # never ask for more rows than are available\n",
    "ep_d = 1             # discriminator updates per step\n",
    "ep_g = 2             # generator updates per step\n",
    "\n",
    "# NOTE(review): not referenced by any visible cell; the models are compiled with the plain 'sgd' string.\n",
    "learning_rate = 0.0001#5e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Discriminator accuracy on Fake : 0.04855875831485588, Real : 0.9216612655637045\n",
      "Discriminator accuracy on Fake : 0.0961794303257718, Real : 0.9250554323725055\n",
      "Discriminator accuracy on Fake : 0.8928534879754392, Real : 0.922667576326113\n",
      "Discriminator accuracy on Fake : 0.8955824663141736, Real : 0.9191881289442265\n",
      "Discriminator accuracy on Fake : 0.22681221217806583, Real : 0.904929217124339\n",
      "Discriminator accuracy on Fake : 0.0016714992324748422, Real : 0.8680027289783387\n",
      "Discriminator accuracy on Fake : 0.0003922906361930752, Real : 0.65650690772642\n",
      "Discriminator accuracy on Fake : 0.1783728466655296, Real : 0.5787651373017226\n",
      "Discriminator accuracy on Fake : 0.7988060719768036, Real : 0.4528739553129797\n",
      "Discriminator accuracy on Fake : 0.999420092103019, Real : 0.0863721644209449\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.33083745522769914\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.44891693672181476\n",
      "Discriminator accuracy on Fake : 0.9999829438853829, Real : 0.548575814429473\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6061231451475354\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6075217465461368\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.5910967081698789\n",
      "Discriminator accuracy on Fake : 0.9999829438853829, Real : 0.6196315879242709\n",
      "Discriminator accuracy on Fake : 0.9997100460515095, Real : 0.6303428279038035\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6305475012792086\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.644806413099096\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7014156575132184\n",
      "Discriminator accuracy on Fake : 0.9967422821081358, Real : 0.7203479447381886\n",
      "Discriminator accuracy on Fake : 0.9836772983114447, Real : 0.7046392631758486\n",
      "Discriminator accuracy on Fake : 0.9999829438853829, Real : 0.6755415316390926\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6320825515947467\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6238785604639263\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6698618454716015\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7356302234351015\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7598328500767525\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7718232986525669\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7761384956506908\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7768889646938427\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7826709875490363\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7867985672863722\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.787702541361078\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7855364148047075\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7869179600886917\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7735800784581273\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7566604127579737\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7484564216271533\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.77127750298482\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.8103360054579567\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.8421797714480641\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.864642674398772\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.8786286883847859\n",
      "Discriminator accuracy on Fake : 0.9999488316561488, Real : 0.8851100119392802\n",
      "Discriminator accuracy on Fake : 0.9947978850417875, Real : 0.8891693672181477\n",
      "Discriminator accuracy on Fake : 0.7013644891693672, Real : 0.8836431860822105\n",
      "Discriminator accuracy on Fake : 0.0560293365171414, Real : 0.883165614872932\n",
      "Discriminator accuracy on Fake : 0.003957018591164933, Real : 0.8693160498038547\n",
      "Discriminator accuracy on Fake : 0.08789015862186594, Real : 0.828927170390585\n",
      "Discriminator accuracy on Fake : 0.9931775541531639, Real : 0.7722497015179942\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.780146682585707\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7934163397578031\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7977485928705441\n",
      "Discriminator accuracy on Fake : 0.9988572403206549, Real : 0.7965717209619648\n",
      "Discriminator accuracy on Fake : 0.002183182670987549, Real : 0.7677809994883166\n",
      "Discriminator accuracy on Fake : 0.0, Real : 0.7119733924611973\n",
      "Discriminator accuracy on Fake : 0.8198874296435272, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.0047757120927852635\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.12097902097902098\n",
      "Discriminator accuracy on Fake : 0.9999658877707658, Real : 0.2129455909943715\n",
      "Discriminator accuracy on Fake : 0.998294388538291, Real : 0.6185911649326283\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.5704588094831997\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6136278355790551\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.642708511001194\n",
      "Discriminator accuracy on Fake : 0.23527204502814258, Real : 0.6363636363636364\n",
      "Discriminator accuracy on Fake : 0.04095173119563363, Real : 0.619836261299676\n",
      "Discriminator accuracy on Fake : 0.9999829438853829, Real : 0.513252601057479\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.056933310591847176\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.06993006993006994\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.08447893569844789\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.09570185911649326\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.18036841207572915\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.30844277673545967\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6468360907385298\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7528398430837455\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7880436636534197\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7795326624594917\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7385297629200068\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.6097219853317414\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.32432201944397065\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.3182670987549036\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.3242537949855023\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.3269316049803855\n",
      "Discriminator accuracy on Fake : 0.9990789698106771, Real : 0.34908749786798565\n",
      "Discriminator accuracy on Fake : 0.15341975098072658, Real : 0.8583830803342999\n",
      "Discriminator accuracy on Fake : 0.574859287054409, Real : 0.7607027119222242\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7176530786286884\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7304451645915061\n",
      "Discriminator accuracy on Fake : 1.0, Real : 0.7429814088350674\n",
      "Discriminator accuracy on Fake : 0.525140712945591, Real : 0.7387003240661777\n",
      "Discriminator accuracy on Fake : 0.0009380863039399625, Real : 0.7123997953266246\n",
      "Discriminator accuracy on Fake : 0.5882312809142077, Real : 0.6977485928705441\n",
      "Discriminator accuracy on Fake : 0.12427085110011939, Real : 0.6752004093467509\n"
     ]
    }
   ],
   "source": [
    "# Positional hyper-parameters, in the order training() unpacks them.\n",
    "arguments = [\n",
    "    rand_dim,      # rand_noise_dim\n",
    "    combined_ep,   # nb_steps\n",
    "    batch_size,    # batch_size\n",
    "    ep_d,          # D_epochs\n",
    "    ep_g,          # G_epochs\n",
    "    base_n_count,  # min_num_neurones\n",
    "]\n",
    "res = training(arguments, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_4 (Dense)              (None, 200)               7000      \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 100)               20100     \n",
      "_________________________________________________________________\n",
      "dense_6 (Dense)              (None, 1)                 101       \n",
      "=================================================================\n",
      "WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
      "Total params: 54,402\n",
      "Trainable params: 27,201\n",
      "Non-trainable params: 27,201\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layer-by-layer summary (output shapes and parameter counts) of the trained discriminator.\n",
    "res[\"discriminator_model\"].summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
